/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com> 
 *
 */
#include "IPProtocol.h"
#include <iomanip> // setw

namespace aiengine {

// Builds the IP protocol handler; all setup is delegated to the Protocol base
// constructor, which receives the given protocol name.
IPProtocol::IPProtocol(const std::string &name)
	: Protocol(name) {
}

// Destructor: explicitly releases the anomaly_ handle before the object goes away.
IPProtocol::~IPProtocol() {
	anomaly_.reset();
}

// Validates an incoming packet as IPv4.
// The packet must be at least header_size bytes long, and once the header
// pointer is set on its payload, isIPver4() must accept it.
// Returns true (and counts a valid packet) on success; otherwise counts an
// invalid packet and returns false.
bool IPProtocol::check(const Packet &packet) {

	const int length = packet.getLength();

	// Too short to even hold the fixed header.
	if (length < header_size) {
		++total_invalid_packets_;
		return false;
	}

	setHeader(packet.getPayload());

	if (!isIPver4()) {
		++total_invalid_packets_;
		return false;
	}

	++total_valid_packets_;
	return true;
}

// Returns true when the current header describes a fragment: either the
// More Fragments flag is set or the fragment offset is non-zero.
bool IPProtocol::isFragment() const {

	const auto offset_field = ntohs(header_->ip_off);

	return (offset_field & (IP_MF | IP_OFFMASK)) != 0;
}

// Processes the IPv4 header of the current packet.
// - Publishes source/destination addresses and the next protocol id on the multiplexer.
// - Trims any trailing padding so the network payload length is the smaller of
//   the IP total length and the captured packet length.
// - Records the TTL and the previous header size on the packet.
// - Flags fragments as anomalies. It returns false for non-first fragments,
//   which per the original intent are not processed by upper levels.
// Returns true when the packet may continue to upper layers.
bool IPProtocol::processPacket(Packet &packet) {

	// Accounts the cycles spent in this function into total_cpu_cycles_.
	CPUCycle cycles(&total_cpu_cycles_);
	MultiplexerPtr mux = mux_.lock();

	++total_packets_;

	mux->address.setSourceAddress(getSrcAddr());
	mux->address.setDestinationAddress(getDstAddr());

	// Some packets carry padding beyond the IP total length; never count it.
	const int ip_bytes = (getPacketLength() < packet.getLength())
		? getPacketLength()
		: packet.getLength();

	mux->total_length = ip_bytes;
	total_bytes_ += ip_bytes;

	packet.net_packet.setPayload(packet.getPayload());
	packet.net_packet.setLength(ip_bytes);

	mux->setNextProtocolIdentifier(getProtocol());
	packet.setPrevHeaderSize(header_size);

	packet.setTTL(getTTL());

#ifdef DEBUG
	std::cout << __FILE__ << ":" << __func__ << ": ip.src(" << getSrcAddrDotNotation() << ")ip.dst(" << getDstAddrDotNotation() << ")ip.id(" << getID() << ")" ;
	std::cout << "ip.hdrlen(" << getIPHeaderLength() << ")ip.len(" << getPacketLength() << ")ip.ttl(" << (int)getTTL() << ")" << std::endl;
#endif

	if (!isFragment())
		return true;

	++total_events_;
	++total_frag_packets_;
	packet.setPacketAnomaly(PacketAnomalyType::IPV4_FRAGMENTATION);
	anomaly_->incAnomaly(PacketAnomalyType::IPV4_FRAGMENTATION);

	// Only the fragment with offset zero carries data the upper levels can
	// process; any later fragment stops here.
	const bool is_later_fragment = (ntohs(header_->ip_off) & IP_OFFMASK) != 0;

	return !is_later_fragment;
}

// Flow-level hook for the IP layer; currently does nothing.
// TODO: Encapsulations such as ip over ip
void IPProtocol::processFlow(Flow *flow) {

	(void)flow; // intentionally unused until encapsulation support exists
}

void IPProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	if (level > 3) 
		out << "\t" << "Total fragment packets: " << std::setw(10) << total_frag_packets_ << std::endl;
	
	if ((level > 5)and(mux_.lock()))
		mux_.lock()->statistics(out);
}

// Writes the IP protocol statistics into a JSON object.
// The "fragment_packets" counter is only emitted for level > 3.
void IPProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	out["fragment_packets"] = total_frag_packets_;
}

// Returns a snapshot of the IP protocol counters: packets, bytes and
// fragmented packets.
CounterMap IPProtocol::getCounters() const {

	CounterMap counters;

	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);
	counters.addKeyValue("fragmented packets", total_frag_packets_);

	return counters;
}

// Clears the counters handled by reset(), then zeroes the event and
// fragmented packet counters maintained by processPacket().
void IPProtocol::resetCounters() {

	reset();

	total_events_ = 0;
	total_frag_packets_ = 0;
}

} // namespace aiengine
